[TRTLLM-8160][feat] Add draft token tree runtime on CDL #8586
Conversation
Force-pushed from b0c522d to 040bc9a
/bot run --disable-fail-fast

PR_Github #22532 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This pull request introduces comprehensive tree-based speculative decoding enhancements across TensorRT-LLM. The changes refactor attention metadata interfaces, drafter model logic, and spec-decoding parameter handling to support dynamic tree-based token generation with explicit batch-level and resource management, including new buffer allocation strategies, tree-aware position tracking, and revised sampling paths.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Engine as PyTorchModelEngine
    participant AttMeta as AttentionMetadata
    participant SpecMeta as Eagle3SpecMetadata
    participant TreeMgr as SpecTreeManager
    participant Drafter as ChainDrafter
    rect rgb(240, 248, 255)
        Note over Engine,Drafter: Tree-Based Speculative Decoding Flow (New)
    end
    Engine->>Engine: _prepare_tp_inputs(resource_manager)
    Engine->>SpecMeta: Set request_accepted_path
    Engine->>AttMeta: update_spec_dec_param(batch_size, spec_metadata, spec_tree_manager, ...)
    AttMeta->>TreeMgr: Retrieve tree structure & buffers
    alt Static Tree Path
        TreeMgr->>AttMeta: Copy spec_dec_packed_mask, position_offsets
        AttMeta->>AttMeta: Populate kv_lens_cuda, seq_lens
    else Dynamic Tree Path
        AttMeta->>AttMeta: Initialize placeholders for dynamic updates
    end
    Engine->>Drafter: forward() with spec_tree_manager
    Drafter->>TreeMgr: get_generation_lengths(), get_masks(), get_offsets()
    Drafter->>Drafter: sample(draft_layer_idx, logits, spec_tree_manager)
    alt Tree Sampling
        Drafter->>TreeMgr: Retrieve per-layer top_k_list
        Drafter->>Drafter: Top-k sampling with tree constraints
    else Linear Sampling
        Drafter->>Drafter: Greedy/standard sampling
    end
    Drafter-->>Engine: Draft tokens with tree indices
    Engine->>Engine: prepare_for_generation_with_tree_decoding()
    Engine->>AttMeta: Update position_ids, masks, and indices per layer
    Engine->>SpecMeta: Update gather indices and hidden-state read/write offsets
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes

Areas requiring extra attention:
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 23
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
773-799: Fix device/dtype mismatch in advanced indexing, convert scalar tensor to int, and ensure tensor-to-list conversion.

Three bugs confirmed:

- Indexing dtype mismatch (lines 792–798): eagle_paths is int32; advanced indexing CPU tensors requires int64 indices.
- Tensor-to-bool ambiguity (line 801): cur_accepted_len is a 0-D tensor; must convert to Python int before comparison.
- Tensor-to-list assignment (lines 825–826): assigning an int32 tensor slice to a list slice stores tensor objects instead of integers; must call .tolist().

Apply the diff:

```diff
-                all_draft_tokens = torch.tensor(request.py_draft_tokens)  # [max_total_draft_tokens]
-                all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(
-                    -1
-                )  # [max_total_draft_tokens]
+                # Host-side CPU tensors, ensure long dtype for indexing
+                all_draft_tokens = torch.as_tensor(request.py_draft_tokens, dtype=torch.long, device="cpu")
+                all_target_tokens = new_tokens_tensor[:, seq_slot, :].squeeze(-1).to(dtype=torch.long, device="cpu")  # [max_total_draft_tokens + 1]
@@
-                for path_idx, path in enumerate(eagle_paths):
-                    path_exclude_root = (
-                        path[1:] - 1
-                    )  # [max_draft_len], '[1:]' since the new_tokens does not contain the root node.
-                    # '-1' is the index shift after exclude the root node.
-                    draft_tokens_indices = path_exclude_root[path_exclude_root >= 0]  # [max_draft_len]
-                    target_tokens_indices = path[path >= 0]  # [max_draft_len + 1]
+                for path_idx, path in enumerate(eagle_paths):
+                    # Convert to long for CPU advanced indexing
+                    path_long = path.to(dtype=torch.long)
+                    path_exclude_root = path_long[1:] - 1  # exclude root; -1 index shift
+                    draft_tokens_indices = path_exclude_root[path_exclude_root >= 0]
+                    target_tokens_indices = path_long[path_long >= 0]
@@
-                    cur_draft_tokens = all_draft_tokens[draft_tokens_indices]
-                    cur_target_tokens = all_target_tokens[target_tokens_indices]
+                    cur_draft_tokens = all_draft_tokens.index_select(0, draft_tokens_indices)
+                    cur_target_tokens = all_target_tokens.index_select(0, target_tokens_indices)
@@
-                    cur_accepted_len = torch.cumprod(
-                        (cur_draft_tokens == cur_target_tokens[:-1]).int(), dim=-1
-                    ).sum()
-
-                    # Accepted one more token from the target model.
-                    cur_accepted_len += 1
-
-                    if cur_accepted_len > longest_accepted_len:
+                    cur_accepted_len = int(torch.cumprod(
+                        (cur_draft_tokens == cur_target_tokens[:-1]).to(torch.int32), dim=-1
+                    ).sum().item()) + 1  # +1 accounts for root
+
+                    if cur_accepted_len > longest_accepted_len:
                         longest_accepted_len = cur_accepted_len
                         longest_match_path_idx = path_idx
@@
-                request.py_num_accepted_draft_tokens_indices[: num_accepted_draft_tokens - 1] = (
-                    eagle_paths[longest_match_path_idx][1:longest_accepted_len]
-                )  # exclude the root node
+                accepted_indices = eagle_paths[longest_match_path_idx][1:longest_accepted_len].tolist()
+                request.py_num_accepted_draft_tokens_indices[: num_accepted_draft_tokens - 1] = accepted_indices  # exclude root
```
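To make the cumprod trick concrete: the cumulative product zeroes out every match after the first mismatch, so its sum is the accepted prefix length. A small, self-contained sketch follows; all paths and token values below are made up for illustration and are not taken from the PR.

```python
import torch

# Three hypothetical root-to-leaf paths, padded with -1 (node 0 is the root).
eagle_paths = torch.tensor([[0, 1, 4, -1],
                            [0, 2, -1, -1],
                            [0, 3, 5, 6]], dtype=torch.int32)
all_draft_tokens = torch.tensor([11, 14, 13, 15, 11, 16], dtype=torch.long)
all_target_tokens = torch.tensor([11, 14, 9, 15, 11, 8, 7], dtype=torch.long)

longest_accepted_len, longest_match_path_idx = 0, 0
for path_idx, path in enumerate(eagle_paths):
    path_long = path.to(torch.long)          # int64 indices for CPU indexing
    draft_idx = path_long[1:] - 1            # exclude root; shift indices by -1
    draft_idx = draft_idx[draft_idx >= 0]
    target_idx = path_long[path_long >= 0]
    cur_draft = all_draft_tokens.index_select(0, draft_idx)
    cur_target = all_target_tokens.index_select(0, target_idx)
    # cumprod zeroes everything after the first mismatch; sum = prefix length.
    matches = (cur_draft == cur_target[:-1]).to(torch.int32)
    accepted = int(torch.cumprod(matches, dim=-1).sum().item()) + 1  # +1 for root
    if accepted > longest_accepted_len:
        longest_accepted_len, longest_match_path_idx = accepted, path_idx

print(longest_match_path_idx, longest_accepted_len)  # path 0, length 2 here
```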
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)

1-4: Missing NVIDIA Apache-2.0 header (2025)

Per repo guidelines, prepend the standard header. Apply at file start:

```diff
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
```

As per coding guidelines.
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
1-1: Missing NVIDIA Apache-2.0 header

Add the required NVIDIA Apache-2.0 header (year 2025).
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1-1: Missing NVIDIA Apache-2.0 header

Add the required NVIDIA Apache-2.0 header (year 2025).
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
1-1: Add required NVIDIA Apache-2.0 header (2025).

File is missing the mandatory license header. Please prepend it. Apply this diff:

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)

310-313: The dynamic path indexing bug is confirmed. The code attempts 3-D indexing ([:, i, :]) on a 2-D tensor (eagle_paths[tree_idx] has shape [max_total_draft_tokens + 1, max_draft_len + 1]), which causes a shape mismatch at assignment. The proposed fix correctly reshapes the nonzero indices to 1-D and assigns them row-wise to the 2-D tensor.
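For illustration, a minimal standalone sketch of the row-wise assignment the fix describes, on a hypothetical four-node tree (names and layout assumed, not the PR's actual buffers):

```python
import torch

# Toy tree: node 0 is the root, nodes 1-2 are its children, node 3 is a child of node 1.
num_nodes, max_draft_len = 4, 2
eagle_path = torch.full((num_nodes, max_draft_len + 1), -1, dtype=torch.int32)
# ancestor_mask[i, j] == 1 iff node j lies on the root-to-node-i path.
ancestor_mask = torch.tensor([[1, 0, 0, 0],
                              [1, 1, 0, 0],
                              [1, 0, 1, 0],
                              [1, 1, 0, 1]], dtype=torch.int32)
for i in range(num_nodes):
    # nonzero() on a 1-D row returns shape [k, 1]; reshape to 1-D before
    # assigning into the 2-D eagle_path row (no 3-D indexing involved).
    nodes_on_path = torch.nonzero(ancestor_mask[i]).reshape(-1)
    eagle_path[i, :nodes_on_path.numel()] = nodes_on_path.to(torch.int32)
```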
🧹 Nitpick comments (13)
tensorrt_llm/_torch/speculative/drafter.py (1)
67-67: Remove useless expression.

self.max_total_draft_tokens is a no-op here. Drop it to satisfy linters and avoid confusion.

```diff
-        self.max_total_draft_tokens
```
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

584-585: Ensure contiguous int32 slice for KV lengths before passing to C++ op.

Slicing returns a view; make it contiguous and int32 to match extension expectations.

```diff
-        past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
+        past_key_value_lengths = (
+            attn_metadata.kv_lens_cuda.narrow(0, 0, len(requests)).to(torch.int32).contiguous()
+        )
```

Confirm torch.ops.tensorrt_llm.update_kv_cache_draft_token_location expects int32 on the same device as other KV tensors.
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (3)
15-25: Make DummyModel.forward fail fast

Use an explicit NotImplementedError to surface unintended calls during refactors.

```diff
 class DummyModel(torch.nn.Module):
@@
-    def forward(self, *args, **kwargs) -> torch.Tensor:
-        pass
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("DummyModel.forward should not be called in this unit test")
```
54-60: Decouple from external model roots to keep unit test hermetic

Avoid requiring llm_models_root for a path that is not used. Consider a benign default like os.environ.get("LLM_MODELS_ROOT", "/tmp") or pass a dummy path.

```diff
-    spec_config = EagleDecodingConfig(
+    spec_config = EagleDecodingConfig(
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
-        speculative_model_dir=eagle_model_dir,
+        speculative_model_dir=os.environ.get("LLM_MODELS_ROOT", "/tmp"),
```
61-67: Assertion style and device selection nits

- Prefer torch.equal(output_tokens, ref_new_tokens) for clarity.
- Derive the ref tensor device from logits.device to avoid hardcoding CUDA.

```diff
-    assert torch.all(output_tokens == ref_new_tokens)
+    assert torch.equal(output_tokens, ref_new_tokens)
```

And when constructing ref_new_tokens:

```diff
-    ref_new_tokens = torch.tensor([...], device='cuda')
+    ref_new_tokens = torch.tensor([...], device=logits.device)
```
tests/integration/defs/test_e2e.py (1)

2060-2093: Refactor eagle_choices string construction for clarity; remove memory-guard suggestion

The --eagle_choices flag is confirmed as supported in quickstart_advanced.py (type=str, default=None). However, refactor the eagle_choices JSON construction using json.dumps() for consistency with existing codebase patterns (e.g., test_e2e.py line 709) and to reduce manual JSON string errors.

The memory-guard suggestion (skipif marker) is unnecessary; the test suite consistently validates memory requirements post-execution via _check_mem_usage(), which is already present and correct in this test (_check_mem_usage(running_log, [27, 0, 0, 0])).
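For example, a small sketch of the json.dumps() pattern the comment recommends (the choices list here is hypothetical):

```python
import json

# Build the flag value from a Python structure instead of a hand-written string.
eagle_choices = [[0], [0, 0], [1], [1, 0]]  # hypothetical tree choices
cmd_arg = f"--eagle_choices={json.dumps(eagle_choices)}"
# -> --eagle_choices=[[0], [0, 0], [1], [1, 0]]
```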
tensorrt_llm/_torch/speculative/eagle3.py (2)

174-178: Ensure paired iterables are same length

Add an assertion before the loop to guarantee request_ids and seq_lens have equal length (useful under Py3.8, where zip(strict=...) is unavailable).

```diff
@@
         if not self.is_draft_model:
-            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
+            assert len(self.request_ids) == len(self.seq_lens), \
+                "request_ids and seq_lens must be the same length"
+            for req_id, seq_len in zip(self.request_ids, self.seq_lens):
```
197-201: Replace fullwidth parenthesis in comment

Use ASCII ) to avoid lint failures (RUF003).

```diff
-        # 2）is_first_draft
+        # 2) is_first_draft
```
tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (3)

17-17: Avoid mutating sys.path in tests

This path hack is brittle in CI. Prefer relying on the test runner's import paths.

```diff
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+# Avoid mutating sys.path; rely on test runner configuration.
```
6-8: Remove unused model path plumbing
llm_models_root()/eagle_model_dir are not needed; pass a dummy string to EagleDecodingConfig to decouple from local assets.

```diff
-from utils.llm_data import llm_models_root
@@
-    models_path = llm_models_root()
-    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"  # It will not actually be used.
+    eagle_model_dir = "unused"
```

Also applies to: 22-24
662-663: Unnecessary unittest entrypoint

This is a pytest-style function test. unittest.main() won't discover it; safe to drop to avoid confusion.

```diff
-if __name__ == "__main__":
-    unittest.main()
+# Intentionally no unittest entrypoint; use pytest discovery.
```
tensorrt_llm/_torch/speculative/drafting_loops.py (1)

145-147: Typo in comment

"toshift" → "to shift".

```diff
-            1] - 1  # shape: [next_layer_gen_len_per_req]. -1 is toshift the root node
+            1] - 1  # shape: [next_layer_gen_len_per_req]. -1 is to shift the root node
```
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)

324-364: Optional: simplify packed-mask computation with bitshifts.

Avoid pow on int tensors and repeated reshape rebinds; use bit operations for clarity and speed. Apply this diff:

```diff
-        num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32)
-        int_tensor = mask_matrix.reshape(
-            -1, num_process_tokens
-        )  # shape: [num_trees * num_process_tokens, num_process_tokens]
-        packed_mask = packed_mask.reshape(
-            -1,
-            num_blocks)  # shape: [num_trees * num_process_tokens, num_blocks]
-
-        for block_idx in range(num_blocks):
-            start_idx = block_idx * 32
-            end_idx = min(start_idx + 32, num_process_tokens)
-            if end_idx < start_idx:
-                break
-            block_bits = int_tensor[:, start_idx:end_idx]
-            weight = torch.pow(
-                2,
-                torch.arange(end_idx - start_idx,
-                             dtype=torch.int32,
-                             device=int_tensor.device))
-            block_value = torch.sum(block_bits * weight, dim=-1)
-            packed_mask[:, block_idx] = block_value
-
-        packed_mask = packed_mask.reshape(num_trees, num_process_tokens,
-                                          num_blocks)
+        num_blocks = math.ceil((self.max_total_draft_tokens + 1) / 32)
+        rows = mask_matrix.reshape(-1, num_process_tokens)
+        out = packed_mask.reshape(-1, num_blocks)
+        for block_idx in range(num_blocks):
+            start = block_idx * 32
+            end = min(start + 32, num_process_tokens)
+            if end <= start:
+                break
+            span = end - start
+            weights = (torch.ones(span, dtype=torch.int32, device=rows.device)
+                       << torch.arange(span, dtype=torch.int32, device=rows.device))
+            out[:, block_idx] = (rows[:, start:end].to(torch.int32) * weights).sum(dim=-1)
```
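As a standalone illustration of the packing scheme (not the PR's buffers; sizes assumed), this sketch packs a boolean mask into 32-bit integer blocks, where bit j of block b encodes mask column b * 32 + j:

```python
import math

import torch

num_tokens = 40  # assumed; anything > 32 exercises multiple blocks
mask = torch.rand(num_tokens, num_tokens) > 0.5
num_blocks = math.ceil(num_tokens / 32)
packed = torch.zeros(num_tokens, num_blocks, dtype=torch.int32)
for b in range(num_blocks):
    start, end = b * 32, min(b * 32 + 32, num_tokens)
    shifts = torch.arange(end - start, dtype=torch.int64)
    # Shift each 0/1 bit to its position and sum; compute in int64 to be safe.
    vals = (mask[:, start:end].to(torch.int64) << shifts).sum(dim=-1)
    packed[:, b] = vals.to(torch.int32)  # bit 31 wraps into the sign bit, as packed masks expect
```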
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
- cpp/tensorrt_llm/thop/attentionOp.cpp (2 hunks)
- tensorrt_llm/_torch/attention_backend/interface.py (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (18 hunks)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
- tensorrt_llm/_torch/speculative/drafter.py (1 hunks)
- tensorrt_llm/_torch/speculative/drafting_loops.py (3 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (4 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (3 hunks)
- tensorrt_llm/_torch/speculative/spec_tree_manager.py (7 hunks)
- tests/integration/defs/test_e2e.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (1 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (10 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/_torch/speculative/drafter.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tests/integration/defs/test_e2e.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- cpp/tensorrt_llm/thop/attentionOp.cpp
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/speculative/eagle3.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/_torch/speculative/drafter.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tests/integration/defs/test_e2e.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/_torch/speculative/drafter.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
- tensorrt_llm/_torch/attention_backend/interface.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tests/integration/defs/test_e2e.py
- tensorrt_llm/_torch/speculative/spec_tree_manager.py
- cpp/tensorrt_llm/thop/attentionOp.cpp
- tensorrt_llm/_torch/speculative/drafting_loops.py
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/attention_backend/trtllm.py
- tensorrt_llm/_torch/speculative/eagle3.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.
Files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧠 Learnings (1)
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.
Applied to files:
cpp/tensorrt_llm/thop/attentionOp.cpp
🧬 Code graph analysis (11)
tensorrt_llm/_torch/speculative/interface.py (1)
- tensorrt_llm/_torch/attention_backend/trtllm.py (1): TrtllmAttention (1172-1609)

tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): attn_metadata (124-125)

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (2)
- tensorrt_llm/_torch/speculative/drafting_loops.py (3): ChainDrafter (289-476), forward (300-430), sample (432-469)
- tensorrt_llm/_torch/speculative/spec_tree_manager.py (1): SpecTreeManager (7-395)

tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (3)
- tensorrt_llm/_torch/speculative/drafting_loops.py (1): prepare_for_generation_with_tree_decoding (110-286)
- tensorrt_llm/_torch/speculative/eagle3.py (2): Eagle3ResourceManager (23-109), Eagle3SpecMetadata (113-266)
- tensorrt_llm/_torch/speculative/spec_tree_manager.py (1): SpecTreeManager (7-395)

tensorrt_llm/_torch/attention_backend/interface.py (2)
- cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h (1): batch_size (167-167)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): spec_metadata (116-117)

tests/integration/defs/test_e2e.py (1)
- tests/integration/defs/conftest.py (3): llm_root (192-193), llm_venv (702-719), llm_models_root (80-94)

tensorrt_llm/_torch/speculative/drafting_loops.py (4)
- tensorrt_llm/_torch/attention_backend/interface.py (10): AttentionMetadata (43-347), num_seqs (249-253), seq_lens (171-172), seq_lens (175-196), seq_lens_cuda (219-220), on_update (158-168), num_contexts (199-200), num_contexts (203-206), num_tokens (271-272), forward (605-628)
- tensorrt_llm/_torch/speculative/eagle3.py (2): Eagle3SpecMetadata (113-266), forward (362-484)
- tensorrt_llm/_torch/speculative/interface.py (1): SpecMetadata (168-256)
- tensorrt_llm/_torch/speculative/spec_tree_manager.py (1): SpecTreeManager (7-395)

tensorrt_llm/_torch/speculative/model_drafter.py (1)
- tensorrt_llm/runtime/generation.py (1): max_draft_tokens (1319-1322)

tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
- tensorrt_llm/_torch/speculative/eagle3.py (2): Eagle3ResourceManager (23-109), Eagle3SpecMetadata (113-266)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (3): ResourceManager (1203-1246), get_resource_manager (1215-1216), KVCacheManager (151-1154)
- tensorrt_llm/llmapi/llm_args.py (1): is_linear_tree (646-649)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (2): spec_metadata (116-117), attn_metadata (124-125)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (2): ScheduledRequests (20-41), batch_size (37-38)
- tensorrt_llm/_torch/attention_backend/interface.py (1): AttentionMetadata (43-347)

tensorrt_llm/_torch/attention_backend/trtllm.py (3)
- tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1): spec_metadata (116-117)
- tensorrt_llm/_utils.py (1): get_sm_version (732-734)
- tensorrt_llm/_torch/speculative/interface.py (1): is_eagle3 (38-39)

tensorrt_llm/_torch/speculative/eagle3.py (2)
- tensorrt_llm/_torch/attention_backend/interface.py (2): seq_lens (171-172), seq_lens (175-196)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): get_slot (1164-1165)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/speculative/drafter.py
67-67: Found useless expression. Either assign it to a variable or remove it.
(B018)
tensorrt_llm/_torch/speculative/spec_tree_manager.py
16-16: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
226-227: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tensorrt_llm/_torch/attention_backend/trtllm.py
1062-1062: Found useless expression. Either assign it to a variable or remove it.
(B018)
1063-1063: Found useless expression. Either assign it to a variable or remove it.
(B018)
1064-1064: Found useless expression. Either assign it to a variable or remove it.
(B018)
1118-1118: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
tensorrt_llm/_torch/speculative/eagle3.py
186-186: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
197-197: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
214-214: Consider [0, *accepted_path] instead of concatenation
(RUF005)
🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
576-581: Using max_total_draft_tokens for static tree is correct

Looping to max_total_draft_tokens aligns with tree semantics and buffer sizes.

170-173: Confirmed: field is properly initialized on all paths

Verification shows py_num_accepted_draft_tokens_indices is initialized in LlmRequest.__init__() at line 485 as self.py_num_accepted_draft_tokens_indices = []. This initialization applies to all instance creation paths:

- Direct instantiation: LlmRequest(request_id=...) goes through __init__
- Child copy: LlmRequest(llm_request=child) also goes through __init__

Since _create_draft_request() creates new requests via the constructor, all instances get the field initialized. The assignment at line 172 safely copies from the source request, which is guaranteed to have the field initialized. No AttributeError risk exists.
PR_Github #22532 [ run ] completed with state
Force-pushed from b17a837 to 655723b
/bot run --disable-fail-fast

PR_Github #22626 [ run ] triggered by Bot. Commit:
PR_Github #22626 [ run ] completed with state
/bot run

PR_Github #22749 [ run ] triggered by Bot. Commit:
PR_Github #22749 [ run ] completed with state
```python
        )  # exclude the root node
        return num_accepted_draft_tokens - 1


def _tree_sampling_batch(
```
Hi @ixlmar, sorry, I know you've made a lot of improvements to batched sampling.

But I decided to remove this function because we're currently implementing the draft token tree only in capturable drafting loops (CDLs), which may have better performance. The corresponding tree sampling will only appear in the sample() function in drafting_loops.py (this approach is somewhat like the one-model design).

Although we have also implemented the draft token tree for non-CDL (draft PR), that version of the drafter requires calling _tree_sampling_batch() after each forward pass, and I currently have no plans to merge it.

I could also keep this function for future use, but I'm not sure whether it would introduce additional maintenance burden. So I'd like to hear your thoughts.

cc @mikeiovine
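For reference, a minimal sketch of what per-layer tree top-k expansion inside such a sample() could look like; this is illustrative only, with assumed shapes and names, not the actual drafting_loops.py implementation:

```python
import torch

def tree_sample_layer(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    # logits: [num_live_nodes, vocab] -> child token ids [num_live_nodes, top_k]
    return torch.topk(logits, k=top_k, dim=-1).indices

vocab_size = 128                     # assumed for the example
top_k_list = [3, 2]                  # assumed per-layer branching factors
root_logits = torch.randn(1, vocab_size)                          # layer 0: root only
layer1_tokens = tree_sample_layer(root_logits, top_k_list[0])     # [1, 3]
layer1_logits = torch.randn(layer1_tokens.numel(), vocab_size)    # one row per new node
layer2_tokens = tree_sample_layer(layer1_logits, top_k_list[1])   # [3, 2]
```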
```python
        model_outputs = {
            "logits": logits,
        }
        # Create the chain drafter
```
Because we moved tree_sampling to sample() in drafting_loops.py, we need to modify these tests accordingly.
```python
    spec_dec_position_offsets: Optional[torch.Tensor] = None

    # TODO: Optimized together with the subsequent dynamic tree.
    # Auxiliary buffers for the static tree.
```
I have added a lot of auxiliary variables for static trees. Because the structure of a static tree is fixed during inference, we can avoid a lot of repeated computation, and it makes the update logic simpler (for example, for the packed mask, position offsets, etc.).
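As a toy illustration of that idea (hypothetical tree, not the PR's actual buffers): since a static tree's topology never changes during inference, per-node quantities such as position offsets can be derived once from node depths at initialization and reused at every step.

```python
import torch

parents = torch.tensor([-1, 0, 0, 1, 1])   # node 0 is the root; parents precede children
depth = torch.zeros_like(parents)
for i in range(1, parents.numel()):
    depth[i] = depth[parents[i]] + 1        # depth of a node = parent's depth + 1
position_offsets = depth                    # fixed for the whole run; no per-step recompute
```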
```diff
         accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
             requests)
-        past_key_value_lengths = attn_metadata.kv_lens_cuda
+        past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
```
Here we only need the first len(requests) entries of the data; without the slice, the shapes do not match and an error is raised.
/bot run

PR_Github #22770 [ run ] triggered by Bot. Commit:
PR_Github #22770 [ run ] completed with state
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Force-pushed from 60c0caf to feb398c
/bot run --disable-fail-fast

PR_Github #22831 [ run ] triggered by Bot. Commit:
PR_Github #22831 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Refactor
Tests
Description
In this PR, we implemented the runtime logic for the draft token tree. Given the improved performance of capturable drafting loops (CDL), our implementation is also based on CDL. A non-CDL draft PR is available here, but it's not considered for merging: #8109
With this PR, we now have the following features:
Unverified:
Required tests for this PR before merging:
[x] Verify that the current implementation is compatible with CUDA Graph
[x] Verify that this PR does not impact other existing functionality
For detailed changes, please refer to the image below:
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.
- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.
skip
skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.